
ENH: cov: expose correction and weights parameters#690

Open
bruAristimunha wants to merge 13 commits intodata-apis:mainfrom
bruAristimunha:cov_parameters

Conversation

@bruAristimunha

Resolves #688.

Summary

  • Adds axis, correction, frequency_weights, and weights parameters to xpx.cov, unlocking the degrees-of-freedom and weighted variants that numpy.cov and torch.cov already support.
  • Naming follows array-api conventions (axis, correction) used elsewhere in this library rather than numpy's (rowvar, bias, ddof). The docstring includes a one-to-one mapping for users migrating from numpy.cov.

Design

The delegation moves observations to the last axis via xp.moveaxis, which collapses rowvar out of backend dispatch entirely — only ddof (numpy/cupy/dask/jax) vs correction (torch) differs between branches.
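The moveaxis idea can be sketched in plain numpy (a minimal illustration with a hypothetical helper name, not the actual delegation code): once observations are moved to the last axis, the backend call never needs `rowvar`.

```python
import numpy as np

def _cov_numpy_like(m, axis, correction):
    # Move the observation axis last, so observations always lie along
    # the final axis and rowvar drops out of the backend dispatch.
    m = np.moveaxis(m, axis, -1)
    return np.cov(m, rowvar=True, ddof=correction)

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 3))  # 100 observations of 3 variables
a = _cov_numpy_like(x, axis=0, correction=1)
b = np.cov(x, rowvar=False, ddof=1)
assert np.allclose(a, b)
```

The same normalization is what lets the torch branch differ from the numpy-family branches only in the `ddof` vs `correction` keyword.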

Fallbacks to the generic implementation (_funcs.cov):

  • m.ndim > 2 (batched input, not supported natively by any backend).
  • Non-integer correction (rejected by numpy.cov's ddof).
  • Dask with weights — dask.array.cov forces .compute() on a lazy 0-D scalar via its internal if fact <= 0 check. The generic path stays fully lazy because its weighted branch doesn't compare fact to zero (noted in docstring).

Weighted formula in _funcs.cov matches numpy's algebraically: c = (m_c * w) @ m_c.T / (v1 - correction * v2 / v1), where m_c is the weighted-mean-centered data, w = frequency_weights * weights, v1 = sum(w), and v2 = sum(w * weights).
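The formula can be checked against numpy directly; this sketch uses numpy's parameter names (`fweights`, `aweights`, `ddof`) rather than this PR's:

```python
import numpy as np

rng = np.random.default_rng(0)
m = rng.standard_normal((3, 50))   # 3 variables x 50 observations
fw = rng.integers(1, 5, size=50)   # frequency weights (integers)
aw = rng.random(50) + 0.5          # reliability weights

w = fw * aw
v1 = w.sum()
v2 = (w * aw).sum()
avg = np.average(m, axis=1, weights=w)
m_c = m - avg[:, None]             # center by the weighted mean
fact = v1 - 1 * v2 / v1            # correction = 1 (unbiased)
c = (m_c * w) @ m_c.T / fact

assert np.allclose(c, np.cov(m, fweights=fw, aweights=aw, ddof=1))
```

Note that the denominator never gets compared to zero here, which is why this path can stay lazy on dask.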

Tests

New TestCov cases validate against np.cov as reference:

  • test_correction (integer ddof)
  • test_correction_float (generic-path-only, hand-computed reference)
  • test_axis / test_axis_with_weights / test_axis_out_of_bounds
  • test_frequency_weights / test_weights / test_both_weights
  • test_batch_with_weights

Test plan

  • pytest tests/test_funcs.py::TestCov — 126 passed across numpy, torch, jax, dask, array-api-strict
  • pytest tests/test_funcs.py full — 4263 passed, 0 failed
  • lefthook run pre-commit — ruff, numpydoc, mypy, pyright, typos all green
  • Dask laziness verified — lazy_xp_function(cov) asserts 0 .compute() calls, holds for weighted path via the fallback

Resolves data-apis#688. Adds `axis`, `correction`, `frequency_weights`, and
`weights` to `cov`, giving users control over the degrees-of-freedom
correction and the observation-axis / weighted variants that
`numpy.cov` and `torch.cov` already support.

Naming follows array-api conventions (`axis`, `correction`) rather
than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a
one-to-one mapping. The delegation moves observations to the last
axis via `xp.moveaxis`, collapsing `rowvar` out of the backend
dispatch — only `ddof` vs `correction` differs between branches.

Dask's native `cov` forces `.compute()` on a lazy scalar when any
weights are given, so weighted dask inputs fall through to the
generic implementation, which is fully lazy.
@betatim
Member

betatim commented Apr 20, 2026

It looks like the cov you are adding follows the pytorch signature; can you explain a bit why you chose that? In my PR I thought following the NumPy API made sense, since most libraries seem to use it.

The PR description mentions that other functions in this library already use correction and axis. Is that a good reason to also do it here? Interested in your thinking.

Comment thread src/array_api_extra/_delegation.py
Comment thread src/array_api_extra/_delegation.py
Comment thread src/array_api_extra/_lib/_funcs.py Outdated
@bruAristimunha
Author

Hey @betatim!

This was a hard decision to make, but I can follow numpy more strictly if you prefer.

I basically looked at what is already implemented in the array API standard and how it handles the parameter names I was trying to introduce.

For each parameter (bias, rowvar, ddof, fweights, and aweights), I checked how numpy's names had been translated here in the past.

For bias and ddof becoming correction, I noticed that xp.var, xp.std, and I think xp.sum rename the numpy parameters to the array API specification names:

https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html
https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html

There was a discussion about using correction instead of bias+ddof for these functions: it was introduced in data-apis/array-api#10, with further discussion later in data-apis/array-api#695, led by @kgryte.

For rowvar becoming axis, I just followed the signature of the other functions; axis seems to be the convention they settled on.

For frequency_weights and weights, the decision came from my experience with pyRiemann. The only place I remember using something similar is statsmodels (freq_weights, var_weights): https://www.statsmodels.org/stable/generated/statsmodels.genmod.generalized_linear_model.GLM.html#statsmodels.genmod.generalized_linear_model.GLM.freq_weights

I think scikit-learn uses sample_weight more, but I can accommodate any request here.

Member

@betatim betatim left a comment


What is your thinking on validating the weights passed in? Things like checking the shapes make sense, that they are all positive (is this actually required? how does it fit with being lazy?)

@bruAristimunha
Author

I like this idea a lot, @betatim! I think it will make the checks in libraries that use array-api-extra much lighter.

@bruAristimunha
Author

FYI @qbarthelemy and @agramfort

Comment thread src/array_api_extra/_lib/_funcs.py Outdated
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Comment thread src/array_api_extra/_delegation.py Outdated
bruAristimunha and others added 2 commits April 20, 2026 13:39
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
@betatim
Member

betatim commented Apr 20, 2026

Thanks a lot for the detailed answer in #690 (comment) - I didn't realise there was precedent for using correction in functions like var. I think it makes sense to copy that and use correction for cov as well. Worth making the translation!

What is the "temporary deployed" thing that keeps happening?

@bruAristimunha
Author

It is not me, @betatim; I think it is something @lucascolley is pushing here: #699

@bruAristimunha
Author

Happy that you liked the response @betatim :)

I think I addressed all the points from you and @qbarthelemy. Can we merge?

@lucascolley
Member

What is the "temporary deployed" thing that keeps happening?

fixed in bd3652a

@lucascolley lucascolley changed the title ENH: expose correction and weights parameters in cov ENH: cov: expose correction and weights parameters Apr 20, 2026
Member

@lucascolley lucascolley left a comment


I took an initial look, seems pretty good!

One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.

Comment thread src/array_api_extra/_lib/_funcs.py
Comment thread src/array_api_extra/_delegation.py Outdated
Comment thread src/array_api_extra/_delegation.py
Comment thread src/array_api_extra/_delegation.py Outdated
Comment thread src/array_api_extra/_lib/_funcs.py
Comment thread src/array_api_extra/_lib/_funcs.py Outdated
@bruAristimunha
Author

hey @betatim,

As you have the first covariance PR on scikit-learn, could you help with this small test requested by @lucascolley?

One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.

@bruAristimunha
Author

hey @lucascolley,

I made it in my branch, built on top of @betatim's work for scikit-learn's first covariance; you can check it here: scikit-learn/scikit-learn#33600

@lucascolley
Member

hey @lucascolley,

I made it in my branch, built on top of @betatim's work for scikit-learn's first covariance; you can check it here: scikit-learn/scikit-learn#33600

thanks! Would be great if you could take a look, Tim

Comment thread src/array_api_extra/_lib/_funcs.py Outdated
bruAristimunha and others added 3 commits April 20, 2026 22:41
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Addresses review feedback (kgryte, betatim) that the motivation for
allowing non-integer correction was not obvious from the docstring:
weighted unbiased correction and autocorrelated data both require
fractional values.
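The fractional-correction motivation can be made concrete in pure numpy, independently of this PR's code: the unbiased weighted variance needs an effective correction of v2/v1, which is generally non-integer.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(40)
w = rng.random(40) + 0.5            # reliability (aweights-style) weights

v1, v2 = w.sum(), (w * w).sum()
mu = np.average(x, weights=w)
# Unbiased weighted variance: the denominator is v1 - v2/v1, so the
# effective degrees-of-freedom correction v2/v1 is fractional.
var_unbiased = (w * (x - mu) ** 2).sum() / (v1 - v2 / v1)

assert np.isclose(var_unbiased, np.cov(x, aweights=w, ddof=1))
```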
Adds tests for the 1-D shape and length checks in the generic cov
path. Raises the diff coverage for this PR from 93.33% to 100%.
@lucascolley lucascolley added this to the 0.10.2 milestone Apr 24, 2026
@bruAristimunha
Author

hey @lucascolley,

I was wondering: could you please approve the CI for the final test run?


Labels

enhancement New feature or request


Development

Successfully merging this pull request may close these issues.

EHN: make covariance more flexible
